/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    SGD.java
 *    Copyright (C) 2009-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.functions;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

import weka.classifiers.RandomizableClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.core.Aggregateable;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

/**
 * <!-- globalinfo-start --> Implements stochastic gradient descent for learning
 * various linear models (binary class SVM, binary class logistic regression,
 * squared loss, Huber loss and epsilon-insensitive loss linear regression).
 * Globally replaces all missing values and transforms nominal attributes into
 * binary ones. It also normalizes all attributes, so the coefficients in the
 * output are based on the normalized data.<br/>
 * For numeric class attributes, the squared, Huber or epsilon-insensitive loss
 * function must be used. Epsilon-insensitive and Huber loss may require a much
 * higher learning rate.
 * <p/>
 * <!-- globalinfo-end -->
 * 
 * <!-- options-start --> Valid options are:
 * <p/>
 * 
 * <pre>
 *  -F
 *  Set the loss function to minimize.
 *  0 = hinge loss (SVM), 1 = log loss (logistic regression),
 *  2 = squared loss (regression), 3 = epsilon insensitive loss (regression),
 *  4 = Huber loss (regression).
 *  (default = 0)
 * </pre>
 * 
 * <pre>
 *  -L
 *  The learning rate. If normalization is
 *  turned off (as it is automatically for streaming data), then the
 *  default learning rate will need to be reduced (try 0.0001).
 *  (default = 0.01).
 * </pre>
 * 
 * <pre>
 *  -R &lt;double&gt;
 *  The lambda regularization constant (default = 0.0001)
 * </pre>
 * 
 * <pre>
 *  -E &lt;integer&gt;
 *  The number of epochs to perform (batch learning only, default = 500)
 * </pre>
 * 
 * <pre>
 *  -C &lt;double&gt;
 *  The epsilon threshold (epsilon-insensitive and Huber loss only, default = 1e-3)
 * </pre>
 * 
 * <pre>
 *  -N
 *  Don't normalize the data
 * </pre>
 * 
 * <pre>
 *  -M
 *  Don't replace missing values
 * </pre>
 * 
 * <pre>
 *  -S &lt;num&gt;
 *  Random number seed.
 *  (default 1)
 * </pre>
 * 
 * <pre>
 *  -output-debug-info
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console
 * </pre>
 * 
 * <pre>
 *  -do-not-check-capabilities
 *  If set, classifier capabilities are not checked before classifier is built
 *  (use with caution).
 * </pre>
 * 
 * <!-- options-end -->
 * 
 * @author Eibe Frank (eibe{[at]}cs{[dot]}waikato{[dot]}ac{[dot]}nz)
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision$
 * 
 */
public class SGD extends RandomizableClassifier implements UpdateableClassifier, OptionHandler, Aggregateable<SGD> {

    /** For serialization */
    private static final long serialVersionUID = -3732968666673530290L;

    /** Filter replacing missing values globally (null if disabled or unused) */
    protected ReplaceMissingValues m_replaceMissing;

    /**
     * Convert nominal attributes to numerically coded binary ones. Uses supervised
     * NominalToBinary in the batch learning case
     */
    protected Filter m_nominalToBinary;

    /** Normalize the training data (null if normalization is turned off) */
    protected Normalize m_normalize;

    /** The regularization parameter (lambda) */
    protected double m_lambda = 0.0001;

    /** The learning rate */
    protected double m_learningRate = 0.01;

    /** Stores the weights (+ bias in the last element) */
    protected double[] m_weights;

    /** The epsilon parameter for epsilon insensitive and Huber loss */
    protected double m_epsilon = 1e-3;

    /**
     * Holds the current iteration number (reset to 1 by reset(); incremented
     * once per training instance processed in updateClassifier())
     */
    protected double m_t;

    /**
     * The number of training instances seen by buildClassifier(). Stays 0 for
     * pure incremental training, which switches the weight-decay schedule.
     */
    protected double m_numInstances;

    /**
     * The number of epochs to perform (batch learning). Total iterations is
     * m_epochs * num instances
     */
    protected int m_epochs = 500;

    /**
     * Turn off normalization of the input data. This option gets forced for
     * incremental training.
     */
    protected boolean m_dontNormalize = false;

    /**
     * Turn off global replacement of missing values. Missing values will be ignored
     * instead. This option gets forced for incremental training.
     */
    protected boolean m_dontReplaceMissing = false;

    /** Holds the header of the training data (attribute info only, no instances) */
    protected Instances m_data;

    /**
     * Returns default capabilities of the classifier.
     *
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();

        // Attribute types this classifier can process
        result.enable(Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capability.MISSING_VALUES);

        // The allowed class type depends on the selected loss function:
        // regression losses need a numeric class, the others a binary one.
        boolean regressionLoss = (m_loss == SQUAREDLOSS) || (m_loss == EPSILON_INSENSITIVE) || (m_loss == HUBER);
        if (regressionLoss) {
            result.enable(Capability.NUMERIC_CLASS);
        } else {
            result.enable(Capability.BINARY_CLASS);
        }
        result.enable(Capability.MISSING_CLASS_VALUES);

        // No minimum number of training instances required
        result.setMinimumNumberInstances(0);

        return result;
    }

    /**
     * Returns the tip text for this property
     * 
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String epsilonTipText() {
        // Fixed typo in the displayed text: "less that this" -> "less than this"
        return "The epsilon threshold for epsilon insensitive and Huber " + "loss. An error with absolute value less than this " + "threshold has loss of 0 for epsilon insensitive loss. " + "For Huber loss this is the boundary between the quadratic " + "and linear parts of the loss function.";
    }

    /**
     * Sets the epsilon threshold on the error used by the epsilon-insensitive
     * and Huber loss functions.
     *
     * @param e the value of epsilon to use
     */
    public void setEpsilon(double e) {
        this.m_epsilon = e;
    }

    /**
     * Gets the epsilon threshold on the error used by the epsilon-insensitive
     * and Huber loss functions.
     *
     * @return the value of epsilon in use
     */
    public double getEpsilon() {
        return this.m_epsilon;
    }

    /**
     * Returns the tip text for this property
     *
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String lambdaTipText() {
        return "The regularization constant. (default = 0.0001)";
    }

    /**
     * Sets the regularization constant lambda.
     *
     * @param lambda the value of lambda to use
     */
    public void setLambda(double lambda) {
        this.m_lambda = lambda;
    }

    /**
     * Gets the regularization constant lambda.
     *
     * @return the current value of lambda
     */
    public double getLambda() {
        return this.m_lambda;
    }

    /**
     * Set the learning rate.
     * 
     * @param lr the learning rate to use.
     */
    public void setLearningRate(double lr) {
        m_learningRate = lr;
    }

    /**
     * Get the learning rate.
     * 
     * @return the learning rate
     */
    public double getLearningRate() {
        return m_learningRate;
    }

    /**
     * Returns the tip text for this property
     * 
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String learningRateTipText() {
        // Fixed: the original concatenation "then" + "the" rendered as
        // "thenthe" in the GUI tip text (missing space between the fragments).
        return "The learning rate. If normalization is turned off " + "(as it is automatically for streaming data), then " + "the default learning rate will need to be reduced (" + "try 0.0001).";
    }

    /**
     * Returns the tip text for this property
     *
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String epochsTipText() {
        return "The number of epochs to perform (batch learning). The total number of iterations is epochs * num instances.";
    }

    /**
     * Sets how many passes over the training data to make (batch learning).
     *
     * @param e the number of epochs to use
     */
    public void setEpochs(int e) {
        this.m_epochs = e;
    }

    /**
     * Gets how many passes over the training data are made (batch learning).
     *
     * @return the current number of epochs
     */
    public int getEpochs() {
        return this.m_epochs;
    }

    /**
     * Turns normalization of the input data off or on.
     *
     * @param m true if normalization is to be disabled.
     */
    public void setDontNormalize(boolean m) {
        this.m_dontNormalize = m;
    }

    /**
     * Reports whether normalization of the input data has been turned off.
     *
     * @return true if normalization has been disabled.
     */
    public boolean getDontNormalize() {
        return this.m_dontNormalize;
    }

    /**
     * Returns the tip text for this property
     *
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String dontNormalizeTipText() {
        return "Turn normalization off";
    }

    /**
     * Turns global replacement of missing values off or on. If turned off,
     * missing values are effectively ignored.
     *
     * @param m true if global replacement of missing values is to be turned off.
     */
    public void setDontReplaceMissing(boolean m) {
        this.m_dontReplaceMissing = m;
    }

    /**
     * Reports whether global replacement of missing values has been disabled.
     *
     * @return true if global replacement of missing values has been turned off
     */
    public boolean getDontReplaceMissing() {
        return this.m_dontReplaceMissing;
    }

    /**
     * Returns the tip text for this property
     *
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String dontReplaceMissingTipText() {
        return "Turn off global replacement of missing values";
    }

    /**
     * Set the loss function to use.
     * 
     * @param function the loss function to use.
     */
    public void setLossFunction(SelectedTag function) {
        if (function.getTags() == TAGS_SELECTION) {
            m_loss = function.getSelectedTag().getID();
        }
    }

    /**
     * Get the current loss function.
     * 
     * @return the current loss function.
     */
    public SelectedTag getLossFunction() {
        return new SelectedTag(m_loss, TAGS_SELECTION);
    }

    /**
     * Returns the tip text for this property
     * 
     * @return tip text for this property suitable for displaying in the
     *         explorer/experimenter gui
     */
    public String lossFunctionTipText() {
        // Updated to list all five supported loss functions; the original text
        // omitted the epsilon-insensitive and Huber losses (see TAGS_SELECTION).
        return "The loss function to use. Hinge loss (SVM), " + "log loss (logistic regression), squared loss (regression), " + "epsilon-insensitive loss (regression) or Huber loss (regression).";
    }

    /**
     * Returns an enumeration describing the available options.
     * 
     * @return an enumeration of all the available options.
     */
    @Override
    public Enumeration<Option> listOptions() {

        Vector<Option> newVector = new Vector<Option>();
        newVector.add(new Option("\tSet the loss function to minimize.\n\t0 = " + "hinge loss (SVM), 1 = log loss (logistic regression),\n\t" + "2 = squared loss (regression), 3 = epsilon insensitive loss (regression)," + "\n\t4 = Huber loss (regression).\n\t(default = 0)", "F", 1, "-F"));
        newVector.add(new Option("\tThe learning rate. If normalization is\n" + "\tturned off (as it is automatically for streaming data), then the\n\t" + "default learning rate will need to be reduced " + "(try 0.0001).\n\t(default = 0.01).", "L", 1, "-L"));
        newVector.add(new Option("\tThe lambda regularization constant " + "(default = 0.0001)", "R", 1, "-R <double>"));
        newVector.add(new Option("\tThe number of epochs to perform (" + "batch learning only, default = 500)", "E", 1, "-E <integer>"));
        // Fixed typo in the option description: "epsilon-insenstive" -> "epsilon-insensitive"
        newVector.add(new Option("\tThe epsilon threshold (" + "epsilon-insensitive and Huber loss only, default = 1e-3)", "C", 1, "-C <double>"));
        newVector.add(new Option("\tDon't normalize the data", "N", 0, "-N"));
        newVector.add(new Option("\tDon't replace missing values", "M", 0, "-M"));

        // Include the options of the superclass (seed, debug flags, etc.)
        newVector.addAll(Collections.list(super.listOptions()));

        return newVector.elements();
    }

    /**
     * 
     * Parses a given list of options.
     * <p/>
     * 
     * <!-- options-start --> Valid options are:
     * <p/>
     * 
     * <pre>
     *  -F
     *  Set the loss function to minimize.
     *  0 = hinge loss (SVM), 1 = log loss (logistic regression),
     *  2 = squared loss (regression), 3 = epsilon insensitive loss (regression),
     *  4 = Huber loss (regression).
     *  (default = 0)
     * </pre>
     * 
     * <pre>
     *  -L
     *  The learning rate. If normalization is
     *  turned off (as it is automatically for streaming data), then the
     *  default learning rate will need to be reduced (try 0.0001).
     *  (default = 0.01).
     * </pre>
     * 
     * <pre>
     *  -R &lt;double&gt;
     *  The lambda regularization constant (default = 0.0001)
     * </pre>
     * 
     * <pre>
     *  -E &lt;integer&gt;
     *  The number of epochs to perform (batch learning only, default = 500)
     * </pre>
     * 
     * <pre>
     *  -C &lt;double&gt;
     *  The epsilon threshold (epsilon-insensitive and Huber loss only, default = 1e-3)
     * </pre>
     * 
     * <pre>
     *  -N
     *  Don't normalize the data
     * </pre>
     * 
     * <pre>
     *  -M
     *  Don't replace missing values
     * </pre>
     * 
     * <pre>
     *  -S &lt;num&gt;
     *  Random number seed.
     *  (default 1)
     * </pre>
     * 
     * <pre>
     *  -output-debug-info
     *  If set, classifier is run in debug mode and
     *  may output additional info to the console
     * </pre>
     * 
     * <pre>
     *  -do-not-check-capabilities
     *  If set, classifier capabilities are not checked before classifier is built
     *  (use with caution).
     * </pre>
     * 
     * <!-- options-end -->
     * 
     * @param options the list of options as an array of strings
     * @throws Exception if an option is not supported
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        // Discard any previously learned model before applying a new configuration
        reset();

        super.setOptions(options);

        // Each option is consumed (removed from the array) by Utils.getOption,
        // so the remaining-options check at the end sees only leftovers.
        String optionValue = Utils.getOption('F', options);
        if (optionValue.length() > 0) {
            setLossFunction(new SelectedTag(Integer.parseInt(optionValue), TAGS_SELECTION));
        }

        optionValue = Utils.getOption('R', options);
        if (optionValue.length() > 0) {
            setLambda(Double.parseDouble(optionValue));
        }

        optionValue = Utils.getOption('L', options);
        if (optionValue.length() > 0) {
            setLearningRate(Double.parseDouble(optionValue));
        }

        optionValue = Utils.getOption('E', options);
        if (optionValue.length() > 0) {
            setEpochs(Integer.parseInt(optionValue));
        }

        optionValue = Utils.getOption('C', options);
        if (optionValue.length() > 0) {
            setEpsilon(Double.parseDouble(optionValue));
        }

        setDontNormalize(Utils.getFlag('N', options));
        setDontReplaceMissing(Utils.getFlag('M', options));

        Utils.checkForRemainingOptions(options);
    }

    /**
     * Gets the current settings of the classifier.
     * 
     * @return an array of strings suitable for passing to setOptions
     */
    @Override
    public String[] getOptions() {
        ArrayList<String> options = new ArrayList<String>();

        options.add("-F");
        options.add("" + getLossFunction().getSelectedTag().getID());
        options.add("-L");
        options.add("" + getLearningRate());
        options.add("-R");
        options.add("" + getLambda());
        options.add("-E");
        options.add("" + getEpochs());
        options.add("-C");
        options.add("" + getEpsilon());
        if (getDontNormalize()) {
            options.add("-N");
        }
        if (getDontReplaceMissing()) {
            options.add("-M");
        }

        Collections.addAll(options, super.getOptions());

        // Use a zero-length array: toArray then always allocates an array of
        // exactly the list's size. The original new String[1] would have
        // returned a trailing null element for an empty list.
        return options.toArray(new String[0]);
    }

    /**
     * Returns a string describing classifier
     * 
     * @return a description suitable for displaying in the explorer/experimenter
     *         gui
     */
    public String globalInfo() {
        // Fixed typo in the displayed text: "epsilon-insensitve" -> "epsilon-insensitive"
        return "Implements stochastic gradient descent for learning" + " various linear models (binary class SVM, binary class" + " logistic regression, squared loss, Huber loss and " + "epsilon-insensitive loss linear regression)." + " Globally replaces all missing values and transforms nominal" + " attributes into binary ones. It also normalizes all attributes," + " so the coefficients in the output are based on the normalized" + " data.\n" + "For numeric class attributes, the squared, Huber or " + "epsilon-insensitive loss function must be used. Epsilon-insensitive " + "and Huber loss may require a much higher learning rate.";
    }

    /**
     * Resets the classifier to its untrained state: forgets the learned
     * weight vector and restarts the iteration counter.
     */
    public void reset() {
        m_weights = null;
        m_t = 1;
    }

    /**
     * Method for building the classifier. Filters the data (missing-value
     * replacement, nominal-to-binary conversion, normalization, in that order)
     * and then trains via stochastic gradient descent.
     * 
     * @param data the set of training instances.
     * @throws Exception if the classifier can't be built successfully.
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        reset();

        // can classifier handle the data?
        getCapabilities().testWithFail(data);

        // Copy first so the caller's dataset is not modified
        data = new Instances(data);
        data.deleteWithMissingClass();

        if (data.numInstances() > 0 && !m_dontReplaceMissing) {
            m_replaceMissing = new ReplaceMissingValues();
            m_replaceMissing.setInputFormat(data);
            data = Filter.useFilter(data, m_replaceMissing);
        }

        // check for only numeric attributes
        boolean onlyNumeric = true;
        for (int i = 0; i < data.numAttributes(); i++) {
            if (i != data.classIndex()) {
                if (!data.attribute(i).isNumeric()) {
                    onlyNumeric = false;
                    break;
                }
            }
        }

        if (!onlyNumeric) {
            // Supervised NominalToBinary needs class values, so fall back to
            // the unsupervised variant when there are no training instances
            if (data.numInstances() > 0) {
                m_nominalToBinary = new weka.filters.supervised.attribute.NominalToBinary();
            } else {
                m_nominalToBinary = new weka.filters.unsupervised.attribute.NominalToBinary();
            }
            m_nominalToBinary.setInputFormat(data);
            data = Filter.useFilter(data, m_nominalToBinary);
        }

        if (!m_dontNormalize && data.numInstances() > 0) {

            m_normalize = new Normalize();
            m_normalize.setInputFormat(data);
            data = Filter.useFilter(data, m_normalize);
        }

        // Non-zero instance count switches updateClassifier() to the batch
        // weight-decay schedule
        m_numInstances = data.numInstances();

        // +1 for the bias term stored in the last element
        m_weights = new double[data.numAttributes() + 1];
        // Keep only the (filtered) header, for printing the model in toString()
        m_data = new Instances(data, 0);

        if (data.numInstances() > 0) {
            data.randomize(new Random(getSeed())); // randomize the data
            train(data);
        }
    }

    /** The hinge loss function (SVM classification). */
    public static final int HINGE = 0;

    /** The log loss function (logistic regression). */
    public static final int LOGLOSS = 1;

    /** The squared loss function (linear regression). */
    public static final int SQUAREDLOSS = 2;

    /** The epsilon insensitive loss function (regression) */
    public static final int EPSILON_INSENSITIVE = 3;

    /** The Huber loss function (robust regression) */
    public static final int HUBER = 4;

    /** The current loss function to minimize */
    protected int m_loss = HINGE;

    /** Loss functions to choose from (IDs match the constants above) */
    public static final Tag[] TAGS_SELECTION = { new Tag(HINGE, "Hinge loss (SVM)"), new Tag(LOGLOSS, "Log loss (logistic regression)"), new Tag(SQUAREDLOSS, "Squared loss (regression)"), new Tag(EPSILON_INSENSITIVE, "Epsilon-insensitive loss (SVM regression)"), new Tag(HUBER, "Huber loss (robust regression)") };

    /**
     * Evaluates the derivative of the selected loss function at z, where z is
     * the margin (classification losses) or the residual (regression losses).
     *
     * @param z the margin or residual
     * @return the value of the loss derivative at z
     */
    protected double dloss(double z) {
        switch (m_loss) {
        case HINGE:
            // Sub-gradient of the hinge: 1 inside the margin, 0 outside
            return (z < 1) ? 1 : 0;
        case LOGLOSS:
            // Branch on the sign of z so exp() is always called with a
            // non-positive argument (avoids overflow for large |z|)
            if (z < 0) {
                return 1.0 / (Math.exp(z) + 1.0);
            }
            double t = Math.exp(-z);
            return t / (t + 1);
        case EPSILON_INSENSITIVE:
            // Zero inside the epsilon tube, +/-1 outside
            if (z > m_epsilon) {
                return 1.0;
            }
            if (z < -m_epsilon) {
                return -1.0;
            }
            return 0;
        case HUBER:
            // Quadratic region inside +/- epsilon, linear outside
            if (Math.abs(z) <= m_epsilon) {
                return z;
            }
            return (z > 0.0) ? m_epsilon : -m_epsilon;
        default:
            // Squared loss
            return z;
        }
    }

    /**
     * Performs the configured number of epochs of stochastic gradient descent
     * over the (already filtered) training data.
     *
     * @param data the transformed training instances
     * @throws Exception if an instance cannot be processed
     */
    private void train(Instances data) throws Exception {
        int numInsts = data.numInstances();
        for (int epoch = 0; epoch < m_epochs; epoch++) {
            for (int index = 0; index < numInsts; index++) {
                // Instances were batch-filtered in buildClassifier(), so no
                // per-instance filtering is needed here (filter = false)
                updateClassifier(data.instance(index), false);
            }
        }
    }

    /**
     * Computes the dot product between a (possibly sparse) instance and a dense
     * weight vector, skipping the class attribute and missing values. The bias
     * term in the last weight element is excluded.
     *
     * @param inst1      the instance
     * @param weights    the weight vector (bias in the last element, unused here)
     * @param classIndex the index of the class attribute, which is skipped
     * @return the dot product
     */
    protected static double dotProd(Instance inst1, double[] weights, int classIndex) {
        double result = 0;

        // p1 walks the instance's stored (sparse) values, p2 walks the dense
        // weight indices; length - 1 excludes the trailing bias element
        int n1 = inst1.numValues();
        int n2 = weights.length - 1;

        // Merge-style walk: advance whichever pointer lags behind, multiply
        // only when the attribute indices line up
        for (int p1 = 0, p2 = 0; p1 < n1 && p2 < n2;) {
            int ind1 = inst1.index(p1);
            int ind2 = p2;
            if (ind1 == ind2) {
                if (ind1 != classIndex && !inst1.isMissingSparse(p1)) {
                    result += inst1.valueSparse(p1) * weights[p2];
                }
                p1++;
                p2++;
            } else if (ind1 > ind2) {
                p2++;
            } else {
                p1++;
            }
        }
        return (result);
    }

    /**
     * Updates the classifier with the given instance.
     * 
     * @param instance the new training instance to include in the model
     * @param filter   true if the instance should pass through any of the filters
     *                 set up in buildClassifier(). When batch training
     *                 buildClassifier() already batch filters all training
     *                 instances so don't need to filter them again here.
     * @exception Exception if the instance could not be incorporated in the model.
     */
    protected void updateClassifier(Instance instance, boolean filter) throws Exception {

        if (!instance.classIsMissing()) {
            if (filter) {
                // Push the instance through the same filter chain that was
                // applied to the batch training data; each filter reference is
                // null if it was not set up in buildClassifier()
                if (m_replaceMissing != null) {
                    m_replaceMissing.input(instance);
                    instance = m_replaceMissing.output();
                }

                if (m_nominalToBinary != null) {
                    m_nominalToBinary.input(instance);
                    instance = m_nominalToBinary.output();
                }

                if (m_normalize != null) {
                    m_normalize.input(instance);
                    instance = m_normalize.output();
                }
            }

            double wx = dotProd(instance, m_weights, instance.classIndex());

            double y;
            double z;
            if (instance.classAttribute().isNominal()) {
                // Classification: encode the two classes as -1/+1; z is the margin
                y = (instance.classValue() == 0) ? -1 : 1;
                z = y * (wx + m_weights[m_weights.length - 1]);
            } else {
                // Regression: z is the residual; y is reset to 1 so the update
                // direction comes entirely from dloss(z)
                y = instance.classValue();
                z = y - (wx + m_weights[m_weights.length - 1]);
                y = 1;
            }

            // Compute multiplier for weight decay
            double multiplier = 1.0;
            if (m_numInstances == 0) {
                // Incremental mode: decay scaled by the iteration counter
                multiplier = 1.0 - (m_learningRate * m_lambda) / m_t;
            } else {
                // Batch mode: decay scaled by the number of training instances
                multiplier = 1.0 - (m_learningRate * m_lambda) / m_numInstances;
            }
            // Decay all attribute weights (the bias in the last element is not decayed)
            for (int i = 0; i < m_weights.length - 1; i++) {
                m_weights[i] *= multiplier;
            }

            // Only need to do the following if the loss is non-zero
            // if (m_loss != HINGE || (z < 1)) {
            if (m_loss == SQUAREDLOSS || m_loss == LOGLOSS || m_loss == HUBER || (m_loss == HINGE && (z < 1)) || (m_loss == EPSILON_INSENSITIVE && Math.abs(z) > m_epsilon)) {

                // Compute Factor for updates
                double factor = m_learningRate * y * dloss(z);

                // Update coefficients for attributes (sparse iteration: only
                // stored values, skipping class and missing entries)
                int n1 = instance.numValues();
                for (int p1 = 0; p1 < n1; p1++) {
                    int indS = instance.index(p1);
                    if (indS != instance.classIndex() && !instance.isMissingSparse(p1)) {
                        m_weights[indS] += factor * instance.valueSparse(p1);
                    }
                }

                // update the bias
                m_weights[m_weights.length - 1] += factor;
            }
            m_t++;
        }
    }

    /**
     * Updates the classifier with the given instance.
     * 
     * @param instance the new training instance to include in the model
     * @exception Exception if the instance could not be incorporated in the model.
     */
    @Override
    public void updateClassifier(Instance instance) throws Exception {
        // Incremental entry point: filter = true so the instance passes through
        // any filters set up in buildClassifier() before the gradient step
        updateClassifier(instance, true);
    }

    /**
     * Computes the distribution for a given instance
     * 
     * @param inst the instance for which distribution is computed
     * @return the distribution (two class probabilities for a nominal class,
     *         a single predicted value for a numeric class)
     * @throws Exception if the distribution can't be computed successfully
     */
    @Override
    public double[] distributionForInstance(Instance inst) throws Exception {
        double[] result = (inst.classAttribute().isNominal()) ? new double[2] : new double[1];

        // Apply the same filter chain used during training; each filter
        // reference is null if it was not set up in buildClassifier()
        if (m_replaceMissing != null) {
            m_replaceMissing.input(inst);
            inst = m_replaceMissing.output();
        }

        if (m_nominalToBinary != null) {
            m_nominalToBinary.input(inst);
            inst = m_nominalToBinary.output();
        }

        if (m_normalize != null) {
            m_normalize.input(inst);
            inst = m_normalize.output();
        }

        double wx = dotProd(inst, m_weights, inst.classIndex());// * m_wScale;
        double z = (wx + m_weights[m_weights.length - 1]);

        // Regression: the raw linear output is the prediction
        if (inst.classAttribute().isNumeric()) {
            result[0] = z;
            return result;
        }

        // Classification: log loss yields calibrated probabilities via the
        // sigmoid; hinge loss yields a hard 0/1 assignment by sign of z
        if (z <= 0) {
            // z = 0;
            if (m_loss == LOGLOSS) {
                result[0] = 1.0 / (1.0 + Math.exp(z));
                result[1] = 1.0 - result[0];
            } else {
                result[0] = 1;
            }
        } else {
            if (m_loss == LOGLOSS) {
                result[1] = 1.0 / (1.0 + Math.exp(-z));
                result[0] = 1.0 - result[1];
            } else {
                result[1] = 1;
            }
        }
        return result;
    }

    /**
     * Returns the learned weight vector, with the bias stored in the last
     * element. Note: this is a reference to the internal array, not a copy.
     *
     * @return the current weights, or null if no model has been built yet
     */
    public double[] getWeights() {
        return m_weights;
    }

    /**
     * Prints out the classifier.
     * 
     * @return a description of the classifier as a string
     */
    @Override
    public String toString() {
        if (m_weights == null) {
            return "SGD: No model built yet.\n";
        }
        // StringBuilder instead of StringBuffer: the buffer is purely local,
        // so the synchronization StringBuffer provides is wasted overhead.
        StringBuilder buff = new StringBuilder();
        buff.append("Loss function: ");
        if (m_loss == HINGE) {
            buff.append("Hinge loss (SVM)\n\n");
        } else if (m_loss == LOGLOSS) {
            buff.append("Log loss (logistic regression)\n\n");
        } else if (m_loss == EPSILON_INSENSITIVE) {
            buff.append("Epsilon insensitive loss (SVM regression)\n\n");
        } else if (m_loss == HUBER) {
            buff.append("Huber loss (robust regression)\n\n");
        } else {
            buff.append("Squared loss (linear regression)\n\n");
        }

        buff.append(m_data.classAttribute().name() + " = \n\n");
        int printed = 0;

        // Print one term per attribute weight; the class attribute's slot is skipped
        for (int i = 0; i < m_weights.length - 1; i++) {
            if (i != m_data.classIndex()) {
                if (printed > 0) {
                    buff.append(" + ");
                } else {
                    buff.append("   ");
                }

                buff.append(Utils.doubleToString(m_weights[i], 12, 4) + " " + ((m_normalize != null) ? "(normalized) " : "") + m_data.attribute(i).name() + "\n");

                printed++;
            }
        }

        // Finally the bias term (last weight element), with its sign spelled out
        if (m_weights[m_weights.length - 1] > 0) {
            buff.append(" + " + Utils.doubleToString(m_weights[m_weights.length - 1], 12, 4));
        } else {
            buff.append(" - " + Utils.doubleToString(-m_weights[m_weights.length - 1], 12, 4));
        }

        return buff.toString();
    }

    /** Number of models combined via aggregate() since the last finalizeAggregation() */
    protected int m_numModels = 0;

    /**
     * Aggregate an object with this one by summing the weight vectors.
     * finalizeAggregation() must be called afterwards to average them.
     * 
     * @param toAggregate the object to aggregate
     * @return the result of aggregation
     * @throws Exception if the supplied object can't be aggregated for some reason
     */
    @Override
    public SGD aggregate(SGD toAggregate) throws Exception {

        if (m_weights == null) {
            throw new Exception("No model built yet, can't aggregate");
        }

        if (!m_data.equalHeaders(toAggregate.m_data)) {
            // Fixed typo in message: "dont" -> "don't"
            throw new Exception("Can't aggregate - data headers don't match: " + m_data.equalHeadersMsg(toAggregate.m_data));
        }

        // Fetch the other model's weights once instead of on every loop iteration
        double[] otherWeights = toAggregate.getWeights();

        if (m_weights.length != otherWeights.length) {
            // Fixed typo in message: "SDG" -> "SGD"
            throw new Exception("Can't aggregate - SGD to aggregate has weight vector " + "that differs in length from ours.");
        }

        // Sum the weight vectors; finalizeAggregation() divides by the count
        for (int i = 0; i < m_weights.length; i++) {
            m_weights[i] += otherWeights[i];
        }

        m_numModels++;

        return this;
    }

    /**
     * Call to complete the aggregation process. Averages the summed weight
     * vectors over all aggregated models (including this one).
     * 
     * @throws Exception if the aggregation can't be finalized for some reason
     */
    @Override
    public void finalizeAggregation() throws Exception {
        if (m_numModels == 0) {
            throw new Exception("Unable to finalize aggregation - " + "haven't seen any models to aggregate");
        }

        // Divide by the number of aggregated models plus one for ourselves
        int divisor = m_numModels + 1;
        for (int i = 0; i < m_weights.length; i++) {
            m_weights[i] /= divisor;
        }

        // Ready for a fresh round of aggregation
        m_numModels = 0;
    }

    /**
     * Main method for testing this class.
     *
     * @param args command-line options (see listOptions())
     */
    public static void main(String[] args) {
        runClassifier(new SGD(), args);
    }
}
